% Copyright (C) 2019-2023 Benjamin Born, Francesco D'Ascanio, Gernot J. Mueller, Johannes Pfeifer
%
% This is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% It is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
% 
% For a copy of the GNU General Public License,
% see <http://www.gnu.org/licenses/>.

% Compute impulse responses to positive and negative g shocks by Monte Carlo,
% starting from states drawn from a long (ergodic) simulation of the model
clear;
rng(1)

% make sure the output folders exist
for out_dir={'Figures','IRF_results'}
    if ~isfolder(out_dir{1})
        mkdir('.',out_dir{1});
    end
end
clear('out_dir');

% output files
save_name='IRF_results/irfs_peg_ergodic_MC.mat';
IRF_save_name='Figures/IRF_peg_ergodic_MC';

% simulation settings
N_periods=20;  % length of each simulated impulse response (quarters)
N_rep=1e6;     % number of replications (not referenced further in this script)
T_ergodic=1e6; % length of the long simulation used for the ergodic mean and starting states
T_burnin=1e4;  % initial periods of the long simulation discarded when computing the mean

%load the solved policy functions
loaded_PFI=load('Policy_functions/pfi_peg_g_low_beta_large_grid.mat');

% Fill in fields that some result files may not contain:
%  - the wage grid is taken from the top level when it is not stored inside grid
%  - the exchange rate policy parameter phi_e defaults to 0
if ~isfield(loaded_PFI.grid,'w_grid_vec'), loaded_PFI.grid.w_grid_vec=loaded_PFI.w_grid_vec; end
if ~isfield(loaded_PFI.par,'phi_e'), loaded_PFI.par.phi_e=0; end

% model parameters copied into the workspace
omega = loaded_PFI.par.omega; % weight in the relative-price formula for the initial conditions
alfa  = loaded_PFI.par.alfa;  % exponent in nontradable output y_N = h^alfa
hbar  = loaded_PFI.par.hbar;
xi    = loaded_PFI.par.xi;    % exponent parameter in the relative-price formula

%% Set initial conditions of exogenous states, debt, and past wage
% Exogenous states (tradable output, interest rate, government spending) are set to
% the middle element of the sorted unique values of their grid; debt and the past
% wage are set to fixed grid points.
% NOTE(review): these are meant to approximate the unconditional means; the middle
% grid point coincides with the mean only for symmetric grids -- confirm.
d_index=307; %debt grid point used as initial condition
w_index=226; %past-wage grid point used as initial condition

temp=unique(loaded_PFI.grid.y_level);
y_0 = temp(round(length(temp)/2));
temp=unique(loaded_PFI.grid.r_level);
r_0 = temp(round(length(temp)/2));
temp=unique(loaded_PFI.grid.g_level);
g_0 = temp(round(length(temp)/2));
%neighbouring g levels (entry after the last / before the first occurrence of g_0 in grid.g_level)
g_0_pos = loaded_PFI.grid.g_level(find(loaded_PFI.grid.g_level==g_0,1,'last')+1);
g_0_neg = loaded_PFI.grid.g_level(find(loaded_PFI.grid.g_level==g_0, 1 )-1);
d_0 = loaded_PFI.grid.d_grid(d_index);
w_0 = loaded_PFI.grid.w_grid_vec(w_index);
%labor input from the policy function at the middle exogenous-state index and the chosen (d,w) grid points
%NOTE(review): assumes round(n_y/2) is the exogenous state matching y_0/r_0/g_0 -- confirm
h_0 = loaded_PFI.h(round(loaded_PFI.grid.n_y/2),d_index,w_index);


%store initial values (column 1: baseline, column 2: g one level up, column 3: g one level down)
y_T_start = repmat(y_0,[1 3]);
r_start = repmat(r_0,[1 3]);
g_N_start = [g_0 g_0_pos g_0_neg];
d_start = repmat(d_0,[1 3]);
w_lag_start = repmat(w_0,[1 3]);
h_start = repmat(h_0,[1 3]);

c_T = y_0 + d_0./(1+r_0) - d_0; %tradable consumption when debt is held constant at d_0
y_N = h_0.^alfa; %nontradable output
c_N = y_N - g_0; %nontradable consumption: nontradable output net of government spending g_N
p_lag_start = (1-omega)/omega*(c_N./c_T).^(-1/xi); %relative price of nontradables in terms of tradables
p_lag_start = repmat(p_lag_start,[1 3]);
clear('c_T','y_N','c_N','r_0','g_0','d_0','g_0_pos','g_0_neg','h_0','d_index','w_index');

%P_trans is the transition probability matrix of the exogenous Markov chain (loaded_PFI.MC);
%Cum_prob holds its row-wise cumulative sums, which the simulator uses to draw realizations
Cum_prob = cumsum(loaded_PFI.MC.P_trans,2);
%long simulation from the initial conditions; columns of Y_base are indexed by the fields of pos
[Y_base, var_name, pos]=simulate_model_general(y_T_start(1),r_start(1),g_N_start(1),d_start(1),w_lag_start(1),p_lag_start(1),T_ergodic,loaded_PFI.dp,loaded_PFI.h,loaded_PFI.grid,loaded_PFI.par,Cum_prob,1);
%ergodic mean over the post-burn-in sample (used to normalize the IRF plots)
Y_base_mean=mean(Y_base(T_burnin+1:end,:));

%% Draw IRF starting periods from the ergodic distribution
%Candidate periods t have no slack (unemployment below the threshold), lie after the
%burn-in (consistent with Y_base_mean), are not period 1, and are not the last simulated
%period, because the exogenous states of t+1 are used as shock-period values.
unemployment_tol=0.01; %threshold below which a period counts as "no slack"
N_IRF_points=1e5; %number of Monte Carlo draws of starting periods
T_no_slack=find(Y_base(:,pos.unemployment)<unemployment_tol);
T_no_slack=T_no_slack(T_no_slack>max(1,T_burnin) & T_no_slack<size(Y_base,1));
if isempty(T_no_slack)
    error('No post-burn-in period without slack found in the ergodic simulation; cannot draw IRF starting points.')
end
T_IRFs=T_no_slack(randi(length(T_no_slack),1,N_IRF_points));


nvar = length(fieldnames(pos)); %number of variables for which IR are computed

%Simulated paths: periods x variables x Monte Carlo draws
Y = zeros(N_periods,nvar,N_IRF_points);     %baseline: initial g equal to its realized value
Y_pos = zeros(N_periods,nvar,N_IRF_points); %initial g one grid point higher
Y_neg = zeros(N_periods,nvar,N_IRF_points); %initial g one grid point lower

%reset random number generator
rng('default')
N_clamped=0; %draws where g is at the top/bottom of its grid, so one of the shocks has size zero
tic
for replic_iter=1:N_IRF_points
    t_0=T_IRFs(replic_iter); %selected no-slack period of the ergodic simulation
    %endogenous states from period t_0, exogenous states from period t_0+1
    d_start_ergodic=Y_base(t_0,pos.D);
    w_start_ergodic=Y_base(t_0,pos.w);
    p_start_ergodic=Y_base(t_0,pos.p);
    y_T_start_ergodic = Y_base(t_0+1,pos.y_T);
    r_start_ergodic = Y_base(t_0+1,pos.r_ann)/4; %NOTE(review): presumably annualized rate -> per-quarter rate; confirm
    g_N_start_ergodic = Y_base(t_0+1,pos.g_N);

    %locate the realized g on its grid (nearest point, tolerant to round-off) and
    %shift it one grid point up/down, clamped to the grid boundaries
    [g_gap,G_index]=min(abs(loaded_PFI.g_values(:)-g_N_start_ergodic));
    if g_gap>1e-8*max(1,abs(g_N_start_ergodic))
        error('g_N value %g from the ergodic simulation is not on the grid loaded_PFI.g_values.',g_N_start_ergodic)
    end
    G_index_neg=max(1,G_index-1);
    G_index_pos=min(loaded_PFI.grid.NG,G_index+1);
    if G_index_neg==G_index || G_index_pos==G_index
        N_clamped=N_clamped+1;
    end

    %common random numbers: baseline and both shocked simulations start from the same
    %RNG state, so their differences are not contaminated by independent draws
    rng_state=rng;
    Y(:,:,replic_iter)=simulate_model_general(y_T_start_ergodic,r_start_ergodic,g_N_start_ergodic,...
        d_start_ergodic,w_start_ergodic,p_start_ergodic,N_periods,loaded_PFI.dp,loaded_PFI.h,loaded_PFI.grid,loaded_PFI.par,Cum_prob);
    rng(rng_state);
    Y_pos(:,:,replic_iter)=simulate_model_general(y_T_start_ergodic,r_start_ergodic,loaded_PFI.g_values(G_index_pos),...
        d_start_ergodic,w_start_ergodic,p_start_ergodic,N_periods,loaded_PFI.dp,loaded_PFI.h,loaded_PFI.grid,loaded_PFI.par,Cum_prob);
    rng(rng_state);
    Y_neg(:,:,replic_iter)=simulate_model_general(y_T_start_ergodic,r_start_ergodic,loaded_PFI.g_values(G_index_neg),...
        d_start_ergodic,w_start_ergodic,p_start_ergodic,N_periods,loaded_PFI.dp,loaded_PFI.h,loaded_PFI.grid,loaded_PFI.par,Cum_prob);
end
toc
if N_clamped>0
    warning('%d of %d draws start at the edge of the g grid; the corresponding positive or negative shock has size zero.',N_clamped,N_IRF_points)
end

%% Plot IRFs
% IRFs(k,:,1) and IRFs(k,:,2) hold the Monte Carlo average of the difference between the
% positive (resp. negative) g-shock path and the baseline path of variable k.
% The plots show these differences divided by the ergodic mean Y_base_mean(k).
IRFs=NaN(nvar,N_periods,2);
vars_per_fig=16; %4x4 subplot grid per figure
for fig_iter=1:ceil(nvar/vars_per_fig)
    %first figure keeps the unnumbered name/file; later ones get a numbered suffix
    if fig_iter==1
        fig_title='Positive shock';
        file_suffix='';
    else
        fig_title=sprintf('Positive shock %d',fig_iter);
        file_suffix=sprintf('_%d',fig_iter);
    end
    h_fig=figure('Name',fig_title);
    var_offset=(fig_iter-1)*vars_per_fig;
    for sub_iter=1:min(vars_per_fig,nvar-var_offset)
        var_iter=var_offset+sub_iter;
        subplot(4,4,sub_iter)
        %average over Monte Carlo draws (3rd dimension)
        IRFs(var_iter,:,1)=mean(Y_pos(:,var_iter,:)-Y(:,var_iter,:),3);
        plot(IRFs(var_iter,:,1)./Y_base_mean(var_iter))
        xlim([1 N_periods]);
        title(var_name{var_iter}, 'interpreter','latex')
    end
    %save inside the loop so a figure is only written when it exists
    print([IRF_save_name '_positive_ergodic_MC' file_suffix],'-depsc2')
    saveas(h_fig,[IRF_save_name '_positive_ergodic_MC' file_suffix])
end


% Negative g-shock IRFs: Monte Carlo average of (negative-shock path - baseline path),
% stored in IRFs(:,:,2) and plotted relative to the ergodic mean, 16 variables per figure
vars_per_fig=16; %4x4 subplot grid per figure
for fig_iter=1:ceil(nvar/vars_per_fig)
    %first figure keeps the unnumbered name/file; later ones get a numbered suffix
    if fig_iter==1
        fig_title='Negative shock';
        file_suffix='';
    else
        fig_title=sprintf('Negative shock %d',fig_iter);
        file_suffix=sprintf('_%d',fig_iter);
    end
    h_fig=figure('Name',fig_title);
    var_offset=(fig_iter-1)*vars_per_fig;
    for sub_iter=1:min(vars_per_fig,nvar-var_offset)
        var_iter=var_offset+sub_iter;
        subplot(4,4,sub_iter)
        %average over Monte Carlo draws (3rd dimension)
        IRFs(var_iter,:,2)=mean(Y_neg(:,var_iter,:)-Y(:,var_iter,:),3);
        plot(IRFs(var_iter,:,2)./Y_base_mean(var_iter))
        xlim([1 N_periods]);
        title(var_name{var_iter}, 'interpreter','latex')
    end
    %save inside the loop so a figure is only written when it exists
    print([IRF_save_name '_negative_ergodic_MC' file_suffix],'-depsc2')
    saveas(h_fig,[IRF_save_name '_negative_ergodic_MC' file_suffix])
end

% store the IRFs together with what is needed to label/normalize them (MAT-file v7.3)
vars_to_save={'IRFs','Y_base_mean','N_periods','var_name','pos'};
save(save_name,vars_to_save{:},'-v7.3')